import json
import torch
from torch.utils.data import Dataset
import random
# Seed the process-wide ``random`` generator once, at import time, so that the
# random.sample / random.shuffle calls made by TextPairDataset start from a
# fixed state. NOTE: this is a module-level side effect -- it reseeds the
# generator for every other user of ``random`` in the process as well, and any
# ``random`` call made between this import and dataset construction changes
# which records are drawn and in what order.
random.seed(42)

class TextPairDataset(Dataset):
    """Map-style dataset of ``(answer, reason, label)`` triples from two JSON files.

    The records of both files are concatenated and shuffled once, at
    construction time, using the module-level ``random`` generator. The
    resulting selection and order therefore depend on that generator's state
    at the moment the dataset is built.
    """

    def __init__(self, pos_json_file, neg_json_file, limit=0):
        """Load both files, optionally subsample them, and shuffle the result.

        Args:
            pos_json_file: Path to a UTF-8 JSON file whose top-level value is
                a list of records. The parameter name suggests positive
                examples, but labels are read from the records themselves.
            neg_json_file: Path to the second JSON file, same format.
            limit: Maximum total number of records to keep. When positive,
                ``limit // 2`` records are drawn at random from the second
                file and the remaining ``limit - limit // 2`` from the first,
                each capped at the number of records the file actually has
                (so the total can be smaller than ``limit``). Zero or a
                negative value keeps every record.

        Raises:
            TypeError: If the top-level JSON value of either file is not a
                list.
        """
        super().__init__()
        pos_data = self._load_records(pos_json_file)
        neg_data = self._load_records(neg_json_file)

        if limit > 0:
            # Give the remainder of an odd limit to the first file so that
            # the requested total is honoured (limit=1 must not produce an
            # empty dataset). For even limits both quotas equal limit // 2.
            neg_quota = limit // 2
            pos_quota = limit - neg_quota
            # Sampling order (first file, then second) is kept fixed so the
            # draws are repeatable for a given generator state.
            pos_data = random.sample(pos_data, min(pos_quota, len(pos_data)))
            neg_data = random.sample(neg_data, min(neg_quota, len(neg_data)))

        self.data = pos_data + neg_data
        # Interleave the two sources; without this all records of the first
        # file would precede all records of the second.
        random.shuffle(self.data)

    @staticmethod
    def _load_records(path):
        """Read *path* as UTF-8 JSON and return its top-level list of records."""
        with open(path, "r", encoding="utf-8") as f:
            records = json.load(f)
        if not isinstance(records, list):
            # Fail here with a clear message instead of later inside
            # random.sample / list concatenation with an unrelated error.
            raise TypeError(
                f"{path}: expected a JSON list of records, "
                f"got {type(records).__name__}"
            )
        return records

    def __len__(self):
        """Return the number of records kept after optional subsampling."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(answer, reason, label)`` triple stored at *idx*.

        Each record is expected to provide the keys ``"answer"``,
        ``"reason"`` and ``"label"``; the values are returned unchanged.
        """
        item = self.data[idx]
        return item["answer"], item["reason"], item["label"]